{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Basic Index Types\n",
    "## Data Preparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "d = 512          # vector dimensionality\n",
    "n_data = 2000    # number of database vectors\n",
    "np.random.seed(0)\n",
    "data = []\n",
    "mu = 3\n",
    "sigma = 0.1\n",
    "for i in range(n_data):\n",
    "    data.append(np.random.normal(mu, sigma, d))\n",
    "data = np.array(data).astype('float32')  # faiss expects float32 arrays\n",
    "\n",
    "# query vectors\n",
    "n_query = 10\n",
    "np.random.seed(12)\n",
    "query = []\n",
    "for i in range(n_query):\n",
    "    query.append(np.random.normal(mu, sigma, d))\n",
    "query = np.array(query).astype('float32')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import faiss (the path below points to a local faiss build; adjust it to your setup)\n",
    "import sys\n",
    "sys.path.append('/home/maliqi/faiss/python/')\n",
    "import faiss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Exact Search for L2\n",
    "A brute-force method: every vector in the database is compared with the query vector. `IndexFlatL2` returns squared L2 distances."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[8.61838   8.782156  8.782816  8.832029  8.837633  8.848496  8.897978\n",
      "  8.916636  8.919006  8.9374   ]\n",
      " [9.033303  9.038907  9.091705  9.15584   9.164591  9.200112  9.201884\n",
      "  9.220335  9.279477  9.312859 ]\n",
      " [8.063818  8.211029  8.306456  8.373352  8.459253  8.459892  8.498557\n",
      "  8.546464  8.555408  8.621426 ]\n",
      " [8.193894  8.211956  8.34701   8.446963  8.45299   8.45486   8.473572\n",
      "  8.50477   8.513636  8.530684 ]\n",
      " [8.369624  8.549444  8.704066  8.736764  8.760082  8.777319  8.831345\n",
      "  8.835486  8.858271  8.860058 ]\n",
      " [8.299072  8.432398  8.434382  8.457374  8.539217  8.562359  8.579033\n",
      "  8.618736  8.630861  8.643393 ]\n",
      " [8.615004  8.615164  8.72604   8.730943  8.762621  8.796932  8.797068\n",
      "  8.797365  8.813985  8.834726 ]\n",
      " [8.377227  8.522776  8.711159  8.724562  8.745737  8.763846  8.768602\n",
      "  8.7727995 8.786856  8.828224 ]\n",
      " [8.342917  8.488056  8.655106  8.662771  8.701336  8.741287  8.743608\n",
      "  8.770507  8.786264  8.849051 ]\n",
      " [8.522164  8.575703  8.68462   8.767247  8.782909  8.850494  8.883733\n",
      "  8.90369   8.909393  8.91768  ]]\n"
     ]
    }
   ],
   "source": [
    "index = faiss.IndexFlatL2(d)\n",
    "# index = faiss.index_factory(d, \"Flat\")  # equivalent alternative to the line above\n",
    "index.add(data)\n",
    "dis, ind = index.search(query, 10)\n",
    "print(dis)"
   ]
  },
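  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal NumPy sketch of what the brute-force search above computes, assuming only NumPy (no faiss). It computes the squared L2 distance from each query to every database vector, then keeps the k smallest per query. The helper name `brute_force_l2` and the toy data are hypothetical. On this notebook's arrays, `brute_force_l2(data, query, 10)` should return the same neighbors as `index.search(query, 10)`, up to float rounding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch with a hypothetical helper and toy data, not faiss's implementation\n",
    "import numpy as np\n",
    "\n",
    "def brute_force_l2(xb, xq, k):\n",
    "    # squared L2 distance between every query and every database vector, shape (nq, nb)\n",
    "    dist = ((xq[:, None, :] - xb[None, :, :]) ** 2).sum(axis=2)\n",
    "    # ids of the k smallest distances per query, in ascending order\n",
    "    idx = np.argsort(dist, axis=1)[:, :k]\n",
    "    rows = np.arange(len(xq))[:, None]\n",
    "    return dist[rows, idx], idx\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "xb = rng.normal(3, 0.1, (200, 16)).astype('float32')\n",
    "xq = xb[:5] + 0.001  # each query is a tiny shift of xb[0..4]\n",
    "D, I = brute_force_l2(xb, xq, 4)\n",
    "print(I[:, 0])  # nearest neighbor of each query: [0 1 2 3 4]"
   ]
  },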
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
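    "## 2. Exact Search for Inner Product\n",
    "Results are ranked by decreasing inner product. When both the database vectors and the query vectors are L2-normalized, the returned score (reported as the distance) equals the cosine similarity."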
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[4621.749  4621.5464 4619.745  4619.381  4619.177  4618.0615 4617.169\n",
      "  4617.0566 4617.0483 4616.631 ]\n",
      " [4637.3975 4637.288  4635.368  4635.2446 4634.881  4633.608  4633.0215\n",
      "  4632.7637 4632.56   4632.373 ]\n",
      " [4621.756  4621.4697 4619.7485 4619.5615 4619.424  4618.0186 4616.9927\n",
      "  4616.962  4616.901  4616.735 ]\n",
      " [4623.6074 4623.5596 4621.3965 4621.158  4620.906  4619.838  4618.9756\n",
      "  4618.9126 4618.7695 4618.478 ]\n",
      " [4625.553  4625.0645 4623.461  4623.196  4622.957  4621.337  4620.7373\n",
      "  4620.717  4620.5635 4620.2485]\n",
      " [4628.489  4628.449  4626.491  4626.487  4625.6406 4624.6143 4624.29\n",
      "  4624.     4623.7524 4623.618 ]\n",
      " [4637.7466 4637.338  4635.3047 4635.125  4634.748  4633.0137 4632.864\n",
      "  4632.58   4632.3027 4632.2324]\n",
      " [4630.472  4630.333  4628.264  4627.9375 4627.738  4626.8965 4625.814\n",
      "  4625.7227 4625.4443 4625.091 ]\n",
      " [4635.7715 4635.489  4633.6904 4633.568  4632.658  4631.463  4631.4307\n",
      "  4631.101  4630.99   4630.3066]\n",
      " [4625.6753 4625.558  4623.454  4623.3926 4623.324  4622.2827 4621.7783\n",
      "  4621.1157 4620.905  4620.854 ]]\n"
     ]
    }
   ],
   "source": [
    "index = faiss.IndexFlatIP(d)\n",
    "index.add(data)\n",
    "dis, ind = index.search(query, 10)\n",
    "print(dis)"
   ]
  },
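  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This NumPy-only sketch makes the cosine-similarity remark concrete using toy data. The helper `normalize_rows` is hypothetical. After both database and query vectors are L2-normalized, their inner product equals the cosine similarity. In faiss, `faiss.normalize_L2(x)` normalizes a float32 array in place and can be applied before `add` and `search`. The notebook's `data` is not normalized, so the scores above are close to 512 * 9 = 4608 rather than at most 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch with a hypothetical helper and toy data\n",
    "import numpy as np\n",
    "\n",
    "def normalize_rows(x):\n",
    "    # scale every row to unit L2 norm\n",
    "    return x / np.linalg.norm(x, axis=1, keepdims=True)\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "xb = rng.normal(0, 1, (100, 8)).astype('float32')\n",
    "xq = rng.normal(0, 1, (3, 8)).astype('float32')\n",
    "\n",
    "# cosine similarity straight from its definition\n",
    "norms_q = np.linalg.norm(xq, axis=1)[:, None]\n",
    "norms_b = np.linalg.norm(xb, axis=1)[None, :]\n",
    "cos = xq.dot(xb.T) / (norms_q * norms_b)\n",
    "# plain inner product of the normalized vectors (the score IndexFlatIP would compute)\n",
    "ip = normalize_rows(xq).dot(normalize_rows(xb).T)\n",
    "# inner-product search ranks by decreasing score\n",
    "top5 = np.argsort(-ip, axis=1)[:, :5]\n",
    "print(np.allclose(cos, ip, atol=1e-5))  # True"
   ]
  },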
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
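    "## 3. HNSW (Hierarchical Navigable Small World graph exploration)\n",
    "Returns approximate results. Compared with the exact L2 distances in section 1, some true nearest neighbors are missed."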
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[8.61838   8.832029  8.848496  8.897978  8.916636  8.9374    8.9597\n",
      "  8.962785  8.984709  8.998907 ]\n",
      " [9.038907  9.164591  9.200112  9.201884  9.220335  9.312859  9.34434\n",
      "  9.344851  9.416974  9.421429 ]\n",
      " [8.306456  8.373352  8.459253  8.546464  8.631898  8.63715   8.63917\n",
      "  8.713682  8.735945  8.7704735]\n",
      " [8.193894  8.211956  8.34701   8.45486   8.473572  8.50477   8.513636\n",
      "  8.530684  8.545482  8.617173 ]\n",
      " [8.369624  8.760082  8.831345  8.858271  8.860058  8.862642  8.936951\n",
      "  8.996922  8.998444  9.022133 ]\n",
      " [8.299072  8.432398  8.434382  8.539217  8.562359  8.698317  8.753672\n",
      "  8.768751  8.779131  8.780444 ]\n",
      " [8.615004  8.615164  8.730943  8.797365  8.861536  8.885755  8.911812\n",
      "  8.922768  8.942963  8.980488 ]\n",
      " [8.377227  8.522776  8.711159  8.724562  8.745737  8.768602  8.7727995\n",
      "  8.786856  8.828224  8.879469 ]\n",
      " [8.342917  8.488056  8.662771  8.741287  8.743608  8.770507  8.857255\n",
      "  8.893716  8.932134  8.933593 ]\n",
      " [8.575703  8.68462   8.850494  8.883733  8.90369   8.909393  8.91768\n",
      "  8.936615  8.961668  8.977329 ]]\n"
     ]
    }
   ],
   "source": [
    "index = faiss.IndexHNSWFlat(d, 16)  # 16 = M, the number of links per node in the graph\n",
    "index.add(data)\n",
    "dis, ind = index.search(query, 10)\n",
    "print(dis)"
   ]
  },
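  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "HNSW answers a query by walking a proximity graph rather than scanning every vector, which is why some exact neighbors are missed above. The sketch below shows only that core idea: greedy search on a single-layer k-nearest-neighbor graph. It is a simplification, not faiss's HNSW. It has no hierarchy of layers and no `efSearch` candidate list, and the graph is built by brute force. All names are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simplified single-layer graph search (hypothetical helpers, toy data), not faiss's HNSW\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "xb = rng.normal(0, 1, (500, 8))\n",
    "k_graph = 10  # out-degree of every node in this toy graph\n",
    "\n",
    "# brute-force k-NN graph (HNSW instead inserts nodes incrementally into layered graphs)\n",
    "d2 = ((xb[:, None, :] - xb[None, :, :]) ** 2).sum(axis=2)\n",
    "np.fill_diagonal(d2, np.inf)\n",
    "graph = np.argsort(d2, axis=1)[:, :k_graph]\n",
    "\n",
    "def greedy_search(q, entry=0):\n",
    "    # hop to the neighbor closest to q until no neighbor is closer (a local minimum)\n",
    "    cur = entry\n",
    "    cur_d = ((xb[cur] - q) ** 2).sum()\n",
    "    while True:\n",
    "        nbrs = graph[cur]\n",
    "        nd = ((xb[nbrs] - q) ** 2).sum(axis=1)\n",
    "        best = nd.argmin()\n",
    "        if nd[best] >= cur_d:\n",
    "            return cur, cur_d\n",
    "        cur, cur_d = nbrs[best], nd[best]\n",
    "\n",
    "q = rng.normal(0, 1, 8)\n",
    "approx_id, approx_d = greedy_search(q)\n",
    "exact_id = ((xb - q) ** 2).sum(axis=1).argmin()\n",
    "print(approx_id == exact_id)  # can be False: greedy search may stop at a local minimum"
   ]
  },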
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Inverted File Search (Inverted file with exact post-verification)\n",
    "This index was introduced in the quick-start section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[8.837633  9.122337  9.217627  9.362019  9.39345   9.396795  9.401556\n",
      "  9.446939  9.52043   9.5279255]\n",
      " [9.436286  9.636714  9.707813  9.714355  9.734249  9.809814  9.87722\n",
      "  9.960412  9.978079  9.982276 ]\n",
      " [8.621426  8.658703  8.842339  8.862192  8.891519  8.937078  8.972767\n",
      "  8.98658   9.007745  9.088661 ]\n",
      " [8.211956  8.735372  8.747662  8.800873  8.917062  9.1208725 9.178852\n",
      "  9.215968  9.2192    9.265095 ]\n",
      " [8.858271  8.998444  9.041813  9.0883045 9.159481  9.169218  9.187948\n",
      "  9.203735  9.204121  9.256811 ]\n",
      " [8.434382  8.539217  8.630861  8.753672  8.768751  8.794859  8.815165\n",
      "  8.817884  8.8404    8.848925 ]\n",
      " [8.861536  8.878873  8.942963  8.944212  8.9446945 8.95914   8.980488\n",
      "  9.051479  9.059914  9.081419 ]\n",
      " [9.15522   9.423113  9.432117  9.465836  9.529045  9.554071  9.556268\n",
      "  9.638275  9.656209  9.69151  ]\n",
      " [8.743608  8.902418  9.065649  9.201052  9.223066  9.223073  9.247414\n",
      "  9.269661  9.288244  9.291237 ]\n",
      " [8.936615  9.077     9.152468  9.1537075 9.313195  9.314999  9.373196\n",
      "  9.400535  9.434517  9.445862 ]]\n"
     ]
    }
   ],
   "source": [
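    "nlist = 50                        # number of inverted lists (k-means cells)\n",
    "quantizer = faiss.IndexFlatL2(d)  # coarse quantizer that stores the centroids\n",
    "index = faiss.IndexIVFFlat(quantizer, d, nlist)\n",
    "index.train(data)                 # k-means to learn the nlist centroids\n",
    "index.add(data)\n",
    "dis, ind = index.search(query, 10)  # nprobe defaults to 1: only the closest list is scanned\n",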
    "print(dis)"
   ]
  },
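  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "IVF and HNSW return approximate results, so a common way to compare them is recall@k: the fraction of the exact k nearest neighbors (for example from `IndexFlatL2`) that the approximate index also returns. This minimal sketch assumes two id arrays of shape (nq, k), like the `ind` returned by `index.search`. The name `recall_at_k` is hypothetical. For `IndexIVFFlat`, raising `index.nprobe` above its default of 1 before searching usually increases recall at the cost of speed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch with a hypothetical helper and toy id arrays\n",
    "import numpy as np\n",
    "\n",
    "def recall_at_k(I_exact, I_approx):\n",
    "    # fraction of exact top-k ids that also appear in the approximate top-k of the same query\n",
    "    hits = 0\n",
    "    for ex, ap in zip(I_exact, I_approx):\n",
    "        hits += len(set(ex.tolist()) & set(ap.tolist()))\n",
    "    return hits / float(I_exact.size)\n",
    "\n",
    "I_exact = np.array([[0, 1, 2, 3], [4, 5, 6, 7]])\n",
    "I_approx = np.array([[0, 2, 9, 8], [4, 5, 6, 7]])\n",
    "print(recall_at_k(I_exact, I_approx))  # 6 of 8 exact neighbors found: 0.75"
   ]
  },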
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
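    "## 5. LSH (Locality-Sensitive Hashing, binary flat index)\n",
    "Each vector is hashed to an nbits-long binary code, and the reported distances are Hamming distances between codes."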
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 8. 10. 10. 10. 10. 10. 10. 11. 11. 11.]\n",
      " [ 7.  8.  9.  9.  9. 10. 10. 10. 10. 10.]\n",
      " [ 7.  8.  8.  9.  9.  9.  9.  9.  9.  9.]\n",
      " [ 9.  9. 10. 11. 12. 12. 12. 12. 12. 12.]\n",
      " [ 6.  6.  6.  7.  7.  8.  8.  8.  8.  8.]\n",
      " [ 8.  8.  8.  9.  9.  9.  9.  9. 10. 10.]\n",
      " [ 6.  7.  8.  8.  9.  9.  9.  9.  9.  9.]\n",
      " [ 9.  9.  9.  9.  9.  9.  9.  9.  9. 10.]\n",
      " [ 7.  8.  8.  8.  8.  8.  8.  9.  9.  9.]\n",
      " [ 9.  9.  9. 10. 10. 10. 10. 10. 10. 10.]]\n"
     ]
    }
   ],
   "source": [
    "nbits = 2 * d  # length of the binary code in bits\n",
    "index = faiss.IndexLSH(d, nbits)\n",
    "index.train(data)\n",
    "index.add(data)\n",
    "dis, ind = index.search(query, 10)\n",
    "print(dis)"
   ]
  },
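  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This plain-NumPy illustration shows the LSH idea. Each vector is projected onto `nbits` random directions, one bit is kept per projection (its sign), and codes are compared by Hamming distance. This is why the distances printed above are small integers. It is a sketch under assumptions, not faiss's exact `IndexLSH` implementation, and the helper names are hypothetical. The sketch subtracts the data mean before hashing. That is an assumption made to keep the bits informative for data that, like this notebook's, lies far from the origin. Without it, almost every vector falls on the same side of every hyperplane."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative random-projection LSH (hypothetical helpers, toy data), not faiss's IndexLSH code\n",
    "import numpy as np\n",
    "\n",
    "def lsh_codes(x, planes, mean):\n",
    "    # one bit per random hyperplane: True if the centered vector lies on its positive side\n",
    "    return (x - mean).dot(planes.T) > 0\n",
    "\n",
    "def hamming(codes_q, codes_b):\n",
    "    # number of differing bits between every query code and every database code\n",
    "    return (codes_q[:, None, :] != codes_b[None, :, :]).sum(axis=2)\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "d, nbits = 16, 64\n",
    "xb = rng.normal(3, 0.1, (300, d))\n",
    "xq = xb[:4] + rng.normal(0, 0.001, (4, d))  # queries are tiny perturbations of xb[0..3]\n",
    "\n",
    "planes = rng.normal(0, 1, (nbits, d))  # random hyperplane normals\n",
    "mean = xb.mean(axis=0)                 # the only 'training' in this sketch\n",
    "H = hamming(lsh_codes(xq, planes, mean), lsh_codes(xb, planes, mean))\n",
    "print(H.argmin(axis=1))  # nearest database code for each query"
   ]
  },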
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Scalar Quantization (Scalar quantizer (SQ) in flat mode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[8.623227  8.777792  8.785317  8.828824  8.83549   8.845292  8.896896\n",
      "  8.914818  8.922382  8.934983 ]\n",
      " [9.028506  9.037546  9.099248  9.1526165 9.16542   9.19639   9.200499\n",
      "  9.224975  9.274046  9.3053875]\n",
      " [8.064029  8.21301   8.310526  8.376435  8.457833  8.462002  8.501087\n",
      "  8.550647  8.556992  8.624525 ]\n",
      " [8.19665   8.210531  8.346436  8.444769  8.452809  8.454114  8.4745245\n",
      "  8.496618  8.510042  8.525612 ]\n",
      " [8.370452  8.547959  8.704323  8.733619  8.763926  8.776738  8.829511\n",
      "  8.835644  8.857149  8.859046 ]\n",
      " [8.29591   8.432422  8.435944  8.454732  8.542395  8.565367  8.579683\n",
      "  8.621871  8.632034  8.644775 ]\n",
      " [8.609016  8.612934  8.72663   8.734133  8.758857  8.797326  8.797966\n",
      "  8.798654  8.815295  8.8382225]\n",
      " [8.378947  8.521084  8.711153  8.726161  8.748383  8.759655  8.768218\n",
      "  8.769182  8.792372  8.834644 ]\n",
      " [8.340463  8.48951   8.659344  8.664954  8.702756  8.741513  8.741941\n",
      "  8.768993  8.781276  8.852154 ]\n",
      " [8.520282  8.574987  8.683459  8.769213  8.7820425 8.85128   8.881118\n",
      "  8.906741  8.907756  8.924014 ]]\n"
     ]
    }
   ],
   "source": [
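    "# the 2nd argument is a quantizer type, not a bit count; the raw value 4 originally passed here is QT_fp16 in current faiss\n",
    "index = faiss.IndexScalarQuantizer(d, faiss.ScalarQuantizer.QT_fp16)\n",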
    "index.train(data)\n",
    "index.add(data)\n",
    "dis, ind = index.search(query, 10)\n",
    "print(dis)"
   ]
  },
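  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Scalar quantization compresses each dimension independently. Below is a minimal NumPy sketch of a per-dimension uniform 8-bit quantizer. Training learns a min/max range for every dimension, and each component is then stored as one byte. This resembles faiss's `QT_8bit` quantizer type, but the helper names are hypothetical and details such as rounding may differ from faiss. A search then compares the query with the decoded vectors, so distances stay close to, but not exactly equal to, the exact ones."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative 8-bit scalar quantizer (hypothetical helpers, toy data), not faiss's implementation\n",
    "import numpy as np\n",
    "\n",
    "def sq_train(x):\n",
    "    # per-dimension value range learned from the training vectors\n",
    "    return x.min(axis=0), x.max(axis=0)\n",
    "\n",
    "def sq_encode(x, vmin, vmax):\n",
    "    # map each dimension to [0, 1], then to an integer level 0..255\n",
    "    scaled = np.clip((x - vmin) / (vmax - vmin), 0, 1)\n",
    "    return np.round(scaled * 255).astype(np.uint8)\n",
    "\n",
    "def sq_decode(codes, vmin, vmax):\n",
    "    # back from levels to floats; the error is at most half a quantization step\n",
    "    return vmin + codes.astype(np.float32) / 255 * (vmax - vmin)\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "xb = rng.normal(3, 0.1, (1000, 32)).astype('float32')\n",
    "vmin, vmax = sq_train(xb)\n",
    "codes = sq_encode(xb, vmin, vmax)  # 1 byte per component instead of 4\n",
    "xb_rec = sq_decode(codes, vmin, vmax)\n",
    "print('%d %d' % (codes.nbytes, xb.nbytes))  # 32000 128000"
   ]
  },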
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Product Quantization (Product quantizer (PQ) in flat mode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[5.3184814 5.33667   5.3638916 5.366333  5.3704834 5.4000244 5.404663\n",
      "  5.415283  5.425659  5.427246 ]\n",
      " [5.6835938 5.686035  5.687134  5.7489014 5.76062   5.7731934 5.7766113\n",
      "  5.7875977 5.798828  5.7991943]\n",
      " [4.902588  5.0057373 5.0323486 5.036255  5.045044  5.048828  5.0498047\n",
      "  5.0499268 5.072998  5.0737305]\n",
      " [4.844116  4.850586  4.868042  4.8946533 4.8997803 4.8999023 4.902954\n",
      "  4.909546  4.9210205 4.921875 ]\n",
      " [5.279419  5.333252  5.3344727 5.3431396 5.35083   5.357422  5.366211\n",
      "  5.3862305 5.38855   5.3936768]\n",
      " [5.019409  5.048706  5.0942383 5.1052246 5.116455  5.157593  5.159424\n",
      "  5.168457  5.171875  5.194092 ]\n",
      " [5.0563965 5.0909424 5.1367188 5.1534424 5.1724854 5.199951  5.20105\n",
      "  5.2144775 5.214966  5.23938  ]\n",
      " [5.16333   5.173706  5.2418213 5.265259  5.265869  5.274414  5.291382\n",
      "  5.307495  5.309204  5.310425 ]\n",
      " [5.1501465 5.2508545 5.291992  5.3186035 5.3205566 5.328369  5.336548\n",
      "  5.3479004 5.35376   5.360962 ]\n",
      " [5.2751465 5.2772217 5.279663  5.3304443 5.350708  5.3571777 5.3669434\n",
      "  5.373047  5.373413  5.382324 ]]\n"
     ]
    }
   ],
   "source": [
    "M = 8      # number of sub-quantizers; must divide d\n",
    "nbits = 6  # bits per sub-vector code, i.e. 2**nbits centroids per sub-space\n",
    "index = faiss.IndexPQ(d, M, nbits)\n",
    "index.train(data)\n",
    "index.add(data)\n",
    "dis, ind = index.search(query, 10)\n",
    "print(dis)"
   ]
  },
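  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Product quantization splits each d-dimensional vector into M sub-vectors, which is why M must divide d. It then runs k-means with 2**nbits centroids in each sub-space and stores one centroid id per sub-vector. A query is compared with the codes through asymmetric distance computation (ADC). A per-sub-space table of query-to-centroid distances is built once, and each database distance is a sum of M table lookups. The compact NumPy sketch below follows that outline. The small k-means and all helper names are hypothetical, not faiss's training code. ADC distances are measured against reconstructed vectors, which is why the PQ distances printed above differ from the exact ones in section 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative PQ + ADC sketch (hypothetical helpers, toy data), not faiss's implementation\n",
    "import numpy as np\n",
    "\n",
    "def kmeans(x, k, n_iter=10, seed=0):\n",
    "    # plain Lloyd iterations, initialized with k distinct training points\n",
    "    rng = np.random.RandomState(seed)\n",
    "    c = x[rng.choice(len(x), k, replace=False)].copy()\n",
    "    for _ in range(n_iter):\n",
    "        assign = ((x[:, None, :] - c[None, :, :]) ** 2).sum(axis=2).argmin(axis=1)\n",
    "        for j in range(k):\n",
    "            if (assign == j).any():\n",
    "                c[j] = x[assign == j].mean(axis=0)\n",
    "    return c\n",
    "\n",
    "def pq_train(x, M, nbits):\n",
    "    # one codebook of 2**nbits centroids per sub-space\n",
    "    dsub = x.shape[1] // M\n",
    "    return [kmeans(x[:, m * dsub:(m + 1) * dsub], 2 ** nbits, seed=m) for m in range(M)]\n",
    "\n",
    "def pq_encode(x, books):\n",
    "    # replace each sub-vector by the id of its nearest centroid\n",
    "    dsub = books[0].shape[1]\n",
    "    codes = np.empty((len(x), len(books)), dtype=np.int64)\n",
    "    for m, c in enumerate(books):\n",
    "        sub = x[:, m * dsub:(m + 1) * dsub]\n",
    "        codes[:, m] = ((sub[:, None, :] - c[None, :, :]) ** 2).sum(axis=2).argmin(axis=1)\n",
    "    return codes\n",
    "\n",
    "def pq_search(q, codes, books, k):\n",
    "    # ADC: table[m, j] = squared distance from the m-th query sub-vector to centroid j\n",
    "    dsub = books[0].shape[1]\n",
    "    table = np.array([((c - q[m * dsub:(m + 1) * dsub]) ** 2).sum(axis=1)\n",
    "                      for m, c in enumerate(books)])\n",
    "    # each database distance is a sum of M table lookups\n",
    "    dist = table[np.arange(len(books)), codes].sum(axis=1)\n",
    "    order = np.argsort(dist)[:k]\n",
    "    return dist[order], order\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "d, M, nbits = 32, 8, 4\n",
    "xb = rng.normal(0, 1, (1000, d))\n",
    "books = pq_train(xb, M, nbits)\n",
    "codes = pq_encode(xb, books)  # M small ids per vector instead of d floats\n",
    "D, I = pq_search(xb[0], codes, books, 5)\n",
    "print(I)"
   ]
  },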
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. Inverted File + Product Quantization (IVFADC: coarse quantizer + PQ on residuals)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[5.1985765 5.209732  5.233874  5.237282  5.2553835 5.262968  5.270462\n",
      "  5.2895284 5.2908745 5.302353 ]\n",
      " [5.5696826 5.5942397 5.611737  5.6186624 5.619787  5.643144  5.646076\n",
      "  5.676093  5.682111  5.6982036]\n",
      " [4.7446747 4.824335  4.834736  4.844829  4.850663  4.853364  4.867393\n",
      "  4.873641  4.8785725 4.88787  ]\n",
      " [4.783175  4.797909  4.8491716 4.85687   4.857151  4.8586845 4.860058\n",
      "  4.866444  4.868099  4.885188 ]\n",
      " [5.1260395 5.134188  5.1386065 5.141901  5.1756086 5.192538  5.1938267\n",
      "  5.1975694 5.199704  5.2012296]\n",
      " [4.882325  4.900981  4.9040375 4.911916  4.916094  4.923492  4.928433\n",
      "  4.928472  4.937878  4.95728  ]\n",
      " [4.9729834 4.976016  4.984484  5.0074816 5.0200887 5.0217285 5.029479\n",
      "  5.029899  5.0346465 5.0349855]\n",
      " [5.1357193 5.147153  5.1525207 5.189519  5.217377  5.220489  5.2341766\n",
      "  5.239973  5.2411985 5.253551 ]\n",
      " [5.0623484 5.087064  5.1075807 5.109309  5.110051  5.1330123 5.1387715\n",
      "  5.1431603 5.151037  5.1516275]\n",
      " [5.12455   5.163775  5.1762547 5.185327  5.190364  5.19723   5.2099175\n",
      "  5.2115583 5.214532  5.2182474]]\n"
     ]
    }
   ],
   "source": [
    "M = 8       # number of sub-quantizers (must divide d)\n",
    "nbits = 4   # bits per sub-vector code\n",
    "nlist = 50  # number of inverted lists\n",
    "quantizer = faiss.IndexFlatL2(d)\n",
    "index = faiss.IndexIVFPQ(quantizer, d, nlist, M, nbits)\n",
    "index.train(data)\n",
    "index.add(data)\n",
    "dis, ind = index.search(query, 10)\n",
    "print(dis)"
   ]
  },
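  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "IVFADC first assigns each vector to its nearest coarse centroid (the IVF part, as in section 4). It then product-quantizes the residual, that is, the vector minus that centroid, instead of the raw vector. Residuals are smaller and more uniform across cells than the raw vectors, so the same PQ code size usually reconstructs them more accurately. The NumPy sketch below only compares those two encodings on toy clustered data. It trains a single PQ on all residuals. The small k-means and the helper names are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative residual-encoding sketch (hypothetical helpers, toy data), not faiss's IndexIVFPQ code\n",
    "import numpy as np\n",
    "\n",
    "def assign(x, c):\n",
    "    # id of the nearest centroid for each row of x\n",
    "    return ((x[:, None, :] - c[None, :, :]) ** 2).sum(axis=2).argmin(axis=1)\n",
    "\n",
    "def kmeans(x, k, n_iter=10, seed=0):\n",
    "    # plain Lloyd iterations, initialized with k distinct training points\n",
    "    rng = np.random.RandomState(seed)\n",
    "    c = x[rng.choice(len(x), k, replace=False)].copy()\n",
    "    for _ in range(n_iter):\n",
    "        a = assign(x, c)\n",
    "        for j in range(k):\n",
    "            if (a == j).any():\n",
    "                c[j] = x[a == j].mean(axis=0)\n",
    "    return c\n",
    "\n",
    "def pq_reconstruct(x, M, nbits):\n",
    "    # train a codebook per sub-space, encode x, and return the decoded vectors\n",
    "    dsub = x.shape[1] // M\n",
    "    rec = np.empty_like(x)\n",
    "    for m in range(M):\n",
    "        sub = x[:, m * dsub:(m + 1) * dsub]\n",
    "        c = kmeans(sub, 2 ** nbits, seed=m)\n",
    "        rec[:, m * dsub:(m + 1) * dsub] = c[assign(sub, c)]\n",
    "    return rec\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "d, M, nbits, nlist = 16, 4, 4, 16\n",
    "centers = rng.normal(0, 5, (8, d))  # toy data: 8 well-separated clusters\n",
    "xb = centers[rng.randint(8, size=2000)] + rng.normal(0, 1, (2000, d))\n",
    "\n",
    "coarse = kmeans(xb, nlist)        # coarse quantizer (IVF part)\n",
    "list_id = assign(xb, coarse)      # inverted list of every vector\n",
    "residual = xb - coarse[list_id]   # what IVFADC encodes with PQ\n",
    "\n",
    "rec_raw = pq_reconstruct(xb, M, nbits)\n",
    "rec_res = coarse[list_id] + pq_reconstruct(residual, M, nbits)\n",
    "err_raw = ((xb - rec_raw) ** 2).sum(axis=1).mean()\n",
    "err_res = ((xb - rec_res) ** 2).sum(axis=1).mean()\n",
    "print('%.2f %.2f' % (err_raw, err_res))  # mean squared reconstruction error: raw PQ vs residual PQ"
   ]
  },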
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
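    "# Cell-probe methods\n",
    "A common way to speed up search is to partition the vector space into cells (for example with k-means). The cost is that the returned results are no longer guaranteed to be exact. Methods that first partition the space and then search only some of the cells are called cell-probe methods.\n",
    "\n",
    "The procedure is:\n",
    "\n",
    "1. The dataset space is partitioned into n cells; with k-means, these are n clusters.\n",
    "2. The vectors of each cell are stored in an inverted list, giving n inverted lists.\n",
    "3. At query time, nprobe inverted lists are selected (those whose centroids are closest to the query).\n",
    "4. The vectors in those lists are compared with the query vector.\n",
    "\n",
    "With this approach only part of the database is scanned, roughly a fraction nprobe/n of the vectors. The fraction is only approximate because the inverted lists differ in length (clusters need not contain the same number of vectors).\n",
    "\n",
    "# Coarse quantization in cell-probe methods\n",
    "Some index types, such as IndexIVFFlat, need a Flat index as the coarse quantizer. During training, the cluster centroids are stored in this Flat index. During add, each vector is first assigned to the cell it falls into; during search, the query is first matched to its closest cell(s). At search time, the nprobe parameter is tuned to trade off accuracy against speed.\n",
    "\n",
    "Experiments show that for high-dimensional data, a fairly large nprobe is needed to maintain accuracy.\n",
    "\n",
    "# Comparison with LSH\n",
    "LSH is also a cell-probe method. Compared with k-means-based partitioning, it has the following drawbacks:\n",
    "\n",
    "1. LSH needs many hash functions, which adds memory overhead.\n",
    "2. The hash functions are not adapted to the input data."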
   ]
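  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The four-step cell-probe procedure can be sketched in NumPy, assuming a coarse quantizer trained by plain k-means. The helpers are hypothetical and the data is toy data. faiss's `IndexIVFFlat` follows the same outline with an optimized implementation. The loop prints how many database vectors each nprobe setting actually scans and how many of the exact top-10 neighbors it recovers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative IVF cell-probe sketch (hypothetical helpers, toy data), not faiss's implementation\n",
    "import numpy as np\n",
    "\n",
    "def kmeans(x, k, n_iter=10, seed=0):\n",
    "    # plain Lloyd iterations, initialized with k distinct training points\n",
    "    rng = np.random.RandomState(seed)\n",
    "    c = x[rng.choice(len(x), k, replace=False)].copy()\n",
    "    for _ in range(n_iter):\n",
    "        a = ((x[:, None, :] - c[None, :, :]) ** 2).sum(axis=2).argmin(axis=1)\n",
    "        for j in range(k):\n",
    "            if (a == j).any():\n",
    "                c[j] = x[a == j].mean(axis=0)\n",
    "    return c\n",
    "\n",
    "def build_ivf(xb, centroids):\n",
    "    # step 2: one inverted list (array of vector ids) per cell\n",
    "    a = ((xb[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2).argmin(axis=1)\n",
    "    return [np.where(a == j)[0] for j in range(len(centroids))]\n",
    "\n",
    "def ivf_search(q, xb, centroids, lists, nprobe, k):\n",
    "    # step 3: select the nprobe cells whose centroids are closest to the query\n",
    "    probe = np.argsort(((centroids - q) ** 2).sum(axis=1))[:nprobe]\n",
    "    cand = np.concatenate([lists[j] for j in probe])\n",
    "    # step 4: exact distances, computed only for vectors in the probed lists\n",
    "    dist = ((xb[cand] - q) ** 2).sum(axis=1)\n",
    "    order = np.argsort(dist)[:k]\n",
    "    return dist[order], cand[order], len(cand)\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "xb = rng.normal(0, 1, (2000, 16))\n",
    "q = rng.normal(0, 1, 16)\n",
    "nlist = 20\n",
    "centroids = kmeans(xb, nlist)  # step 1: partition the space into nlist cells\n",
    "lists = build_ivf(xb, centroids)\n",
    "exact = np.argsort(((xb - q) ** 2).sum(axis=1))[:10]\n",
    "\n",
    "for nprobe in (1, 5, nlist):\n",
    "    D, I, scanned = ivf_search(q, xb, centroids, lists, nprobe, 10)\n",
    "    found = len(set(I.tolist()) & set(exact.tolist()))\n",
    "    # the number of scanned vectors grows roughly like nprobe / nlist of the database\n",
    "    print('nprobe=%d scanned=%d exact top-10 found=%d' % (nprobe, scanned, found))"
   ]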
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
